{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bs4 import BeautifulSoup\n",
    "import re\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.model_selection import KFold\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_string(string):\n",
    "    \"\"\"Split ``string`` into tokens after blanking out unwanted characters.\n",
    "\n",
    "    Every run of characters other than ASCII letters, digits, ``-``, ``/``\n",
    "    and spaces is replaced by one space before splitting on whitespace.\n",
    "\n",
    "    Args:\n",
    "        string: Raw text to tokenize.\n",
    "\n",
    "    Returns:\n",
    "        list of str: The tokens in order; empty when nothing survives.\n",
    "    \"\"\"\n",
    "    # Raw pattern: the non-raw form relied on invalid string escapes, which\n",
    "    # newer Python versions warn about.\n",
    "    cleaned = re.sub(r'[^A-Za-z0-9\\-/ ]+', ' ', string)\n",
    "    # split() without arguments already trims and drops empty tokens, so no\n",
    "    # per-token strip() is needed.\n",
    "    return cleaned.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_raw(filename):\n",
    "    \"\"\"Read an entity-annotated file and flatten it into tokens and labels.\n",
    "\n",
    "    The file is parsed with BeautifulSoup and its prettified form is walked\n",
    "    line by line. A line starting with an opening tag makes the tokens of\n",
    "    the following line take that tag's name as label; lines starting with a\n",
    "    closing tag are ignored; tokens of any other line are labelled 'OTHER'.\n",
    "\n",
    "    Args:\n",
    "        filename: Path of the annotated text file.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(texts, labels)``, two equally long lists of strings.\n",
    "    \"\"\"\n",
    "    with open(filename, 'r') as fopen:\n",
    "        raw = fopen.read()\n",
    "    pretty_lines = BeautifulSoup(raw, 'html.parser').prettify().split('\\n')\n",
    "\n",
    "    texts, labels = [], []\n",
    "    pending_tag = ''\n",
    "    for line in pretty_lines:\n",
    "        if pending_tag:\n",
    "            # Line right after an opening tag: its tokens belong to that tag.\n",
    "            tokens = process_string(line)\n",
    "            texts.extend(tokens)\n",
    "            labels.extend([pending_tag] * len(tokens))\n",
    "            pending_tag = ''\n",
    "            continue\n",
    "        if line.startswith('</'):\n",
    "            # Closing tag: contributes no tokens.\n",
    "            continue\n",
    "        if line.startswith('<'):\n",
    "            # Opening tag: keep what sits between '<' and the first '>'.\n",
    "            pending_tag = line.split('>')[0][1:]\n",
    "            continue\n",
    "        tokens = process_string(line)\n",
    "        texts.extend(tokens)\n",
    "        labels.extend(['OTHER'] * len(tokens))\n",
    "\n",
    "    assert (len(texts)==len(labels)), \"length texts and labels are not same\"\n",
    "    print('len texts and labels: ', len(texts))\n",
    "    return texts,labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len texts and labels:  34012\n",
      "len texts and labels:  9249\n"
     ]
    }
   ],
   "source": [
    "# Parse both annotated files, then fold the test tokens into the training\n",
    "# lists so everything downstream works on one merged corpus.\n",
    "train_texts, train_labels = parse_raw('data_train.txt')\n",
    "test_texts, test_labels = parse_raw('data_test.txt')\n",
    "train_texts.extend(test_texts)\n",
    "train_labels.extend(test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array(['OTHER', 'location', 'organization', 'person', 'quantity', 'time'],\n",
       "       dtype='<U12'), array([35613,  1536,  1592,  2358,  1336,   826]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Label distribution of the merged train + test tokens.\n",
    "np.unique(train_labels,return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each lexicon line is expected to hold a token followed by its tag.\n",
    "with open('entities-bm-normalize-v3.txt','r') as fopen:\n",
    "    # Iterating the handle keeps the last entry even when the file has no\n",
    "    # trailing newline; lines with fewer than two fields are skipped.\n",
    "    entities_bm = [fields for fields in (line.split() for line in fopen) if len(fields) >= 2]\n",
    "# Force the TIME tag on the token 'jam'. Equality, not `in`: `token in 'jam'`\n",
    "# is a substring test that would also match 'j', 'a', 'm', 'ja' and 'am'.\n",
    "entities_bm = [[fields[0], 'TIME' if fields[0] == 'jam' else fields[1]] for fields in entities_bm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'KN'\n",
      "'KA'\n"
     ]
    }
   ],
   "source": [
    "# Map lexicon tags onto the label names used by the annotated corpus.\n",
    "replace_by = {'LOC':'location','PRN':'person','NORP':'organization','ORG':'organization','LAW':'law',\n",
    "             'EVENT':'OTHER','FAC':'organization','TIME':'time','O':'OTHER','ART':'person','DOC':'law'}\n",
    "# NOTE: this loop appends to train_texts / train_labels in place, so running\n",
    "# the cell twice adds the lexicon twice.\n",
    "for token, tag in entities_bm:\n",
    "    # Tokenize once; only the first piece of the cleaned entry is kept.\n",
    "    pieces = process_string(token)\n",
    "    if not pieces:\n",
    "        continue\n",
    "    if tag not in replace_by:\n",
    "        # Unmapped tag: report it and skip the entry. Any other failure now\n",
    "        # surfaces instead of being hidden by a broad `except Exception`.\n",
    "        print(repr(tag))\n",
    "        continue\n",
    "    train_texts.append(pieces[0])\n",
    "    train_labels.append(replace_by[tag])\n",
    "\n",
    "assert (len(train_texts)==len(train_labels)), \"length texts and labels are not same\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array(['OTHER', 'law', 'location', 'organization', 'person', 'quantity',\n",
       "        'time'], dtype='<U12'),\n",
       " array([47406,   107,  2010,  2435,  3913,  1336,  1240]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Label distribution once the lexicon entries have been appended.\n",
    "np.unique(train_labels,return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabularies shared by the whole notebook; index 0 is reserved for 'PAD'.\n",
    "word2idx = {'PAD': 0,'NUM':1,'UNK':2}\n",
    "tag2idx = {'PAD': 0}\n",
    "char2idx = {'PAD': 0}\n",
    "# Next free index of each vocabulary.\n",
    "word_idx = 3\n",
    "tag_idx = 1\n",
    "char_idx = 1\n",
    "\n",
    "def parse_XY(texts, labels):\n",
    "    \"\"\"Index tokens and labels, growing the global vocabularies on the fly.\n",
    "\n",
    "    Characters are registered from the token as given, words from its\n",
    "    lower-cased form.\n",
    "\n",
    "    Args:\n",
    "        texts: Sequence of tokens.\n",
    "        labels: Sequence of tag names; ``labels[k]`` belongs to ``texts[k]``.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(X, Y)`` where ``X`` is a list of word indices and ``Y`` a\n",
    "        NumPy array of tag indices, one entry per token.\n",
    "    \"\"\"\n",
    "    global word2idx, tag2idx, char2idx, word_idx, tag_idx, char_idx\n",
    "    word_ids, tag_ids = [], []\n",
    "    for position, token in enumerate(texts):\n",
    "        tag = labels[position]\n",
    "        # Register unseen characters of the original-case token.\n",
    "        for char in token:\n",
    "            if char in char2idx:\n",
    "                continue\n",
    "            char2idx[char] = char_idx\n",
    "            char_idx += 1\n",
    "        if tag not in tag2idx:\n",
    "            tag2idx[tag] = tag_idx\n",
    "            tag_idx += 1\n",
    "        tag_ids.append(tag2idx[tag])\n",
    "        # Words are looked up case-insensitively.\n",
    "        lowered = token.lower()\n",
    "        if lowered not in word2idx:\n",
    "            word2idx[lowered] = word_idx\n",
    "            word_idx += 1\n",
    "        word_ids.append(word2idx[lowered])\n",
    "    return word_ids, np.array(tag_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['OTHER', 'law', 'location', 'organization', 'person', 'quantity',\n",
       "       'time'], dtype='<U12')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct label names that parse_XY will turn into tag indices.\n",
    "np.unique(train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, Y = parse_XY(train_texts, train_labels)\n",
    "# Reverse lookups: index -> word and index -> tag.\n",
    "idx2word = {index: word for word, index in word2idx.items()}\n",
    "idx2tag = {index: tag for tag, index in tag2idx.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_len = 50\n",
    "def iter_seq(x):\n",
    "    \"\"\"Cut ``x`` into overlapping windows of ``seq_len`` items with stride 1.\n",
    "\n",
    "    Args:\n",
    "        x: Indexable sequence (list or array).\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: One row per window. A sequence of length ``n`` yields\n",
    "        ``n - seq_len + 1`` windows, none when ``n < seq_len``.\n",
    "    \"\"\"\n",
    "    # '+ 1' keeps the final full window x[-seq_len:]; stopping the range at\n",
    "    # len(x) - seq_len would silently drop it.\n",
    "    return np.array([x[i: i+seq_len] for i in range(len(x) - seq_len + 1)])\n",
    "\n",
    "def to_train_seq(*args):\n",
    "    \"\"\"Apply iter_seq to every argument and return the results as a list.\"\"\"\n",
    "    return [iter_seq(x) for x in args]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(58397, 50)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Window the word indices and the tag indices in lockstep.\n",
    "X_seq, Y_seq = to_train_seq(X, Y)\n",
    "X_seq.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Persist the vocabularies. JSON object keys are always strings, so the\n",
    "# integer keys of idx2tag / idx2word come back as str when reloaded.\n",
    "with open('crf-lstm-bidirectional.json','w') as fopen:\n",
    "    json.dump({'idx2tag':idx2tag,'idx2word':idx2word,\n",
    "               'word2idx':word2idx,'tag2idx':tag2idx}, fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.utils import to_categorical\n",
    "# One-hot encode every window of tag indices; the class count includes 'PAD'.\n",
    "Y_seq_3d = [to_categorical(i, num_classes=len(tag2idx)) for i in Y_seq]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "# sklearn.cross_validation was deprecated in 0.18 and removed in 0.20;\n",
    "# train_test_split lives in sklearn.model_selection.\n",
    "from sklearn.model_selection import train_test_split\n",
    "# NOTE(review): the windows built by iter_seq overlap (stride 1), so a random\n",
    "# split puts near-identical windows on both sides and held-out scores are\n",
    "# optimistic.\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(X_seq, Y_seq_3d, test_size=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.2.2\n"
     ]
    }
   ],
   "source": [
    "# Record the Keras version this notebook was run with.\n",
    "import keras\n",
    "print(keras.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import Model, Input\n",
    "from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional\n",
    "from keras_contrib.layers import CRF\n",
    "from keras.backend.tensorflow_backend import set_session\n",
    "# Hand Keras an interactive TensorFlow session to run its graph in.\n",
    "set_session(tf.InteractiveSession())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layer stack: embedding -> bidirectional LSTM -> per-step dense -> CRF.\n",
    "# `model` is only the running tensor here; the Model object is created from\n",
    "# input_word and out afterwards.\n",
    "max_len = seq_len\n",
    "# Sequences of word indices; the sequence length is left unspecified.\n",
    "input_word = Input(shape=(None,))\n",
    "# mask_zero=True marks index 0 (the 'PAD' entry of word2idx) as padding.\n",
    "model = Embedding(input_dim=len(word2idx) + 1, output_dim=128, mask_zero=True)(input_word)\n",
    "model = Bidirectional(LSTM(units=64, return_sequences=True, recurrent_dropout=0.1))(model)\n",
    "model = TimeDistributed(Dense(50, activation=\"relu\"))(model)\n",
    "# One CRF unit per entry of tag2idx, 'PAD' included.\n",
    "crf = CRF(len(tag2idx))\n",
    "out = crf(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the graph from the word-index input to the CRF output in a Model.\n",
    "model = Model(input_word, out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train with the CRF layer's own loss and accuracy metric.\n",
    "model.compile(optimizer=\"rmsprop\", loss=crf.loss_function, metrics=[crf.accuracy])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         (None, None)              0         \n",
      "_________________________________________________________________\n",
      "embedding_1 (Embedding)      (None, None, 128)         1215872   \n",
      "_________________________________________________________________\n",
      "bidirectional_1 (Bidirection (None, None, 128)         98816     \n",
      "_________________________________________________________________\n",
      "time_distributed_1 (TimeDist (None, None, 50)          6450      \n",
      "_________________________________________________________________\n",
      "crf_1 (CRF)                  (None, None, 8)           488       \n",
      "=================================================================\n",
      "Total params: 1,321,626\n",
      "Trainable params: 1,321,626\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the layers with their output shapes and parameter counts.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg height=\"337pt\" viewBox=\"0.00 0.00 321.00 337.00\" width=\"321pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 333)\">\n",
       "<title>G</title>\n",
       "<polygon fill=\"white\" points=\"-4,4 -4,-333 317,-333 317,4 -4,4\" stroke=\"none\"/>\n",
       "<!-- 140568954872776 -->\n",
       "<g class=\"node\" id=\"node1\"><title>140568954872776</title>\n",
       "<polygon fill=\"none\" points=\"94,-292.5 94,-328.5 219,-328.5 219,-292.5 94,-292.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"156.5\" y=\"-306.8\">input_1: InputLayer</text>\n",
       "</g>\n",
       "<!-- 140568954872944 -->\n",
       "<g class=\"node\" id=\"node2\"><title>140568954872944</title>\n",
       "<polygon fill=\"none\" points=\"76,-219.5 76,-255.5 237,-255.5 237,-219.5 76,-219.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"156.5\" y=\"-233.8\">embedding_1: Embedding</text>\n",
       "</g>\n",
       "<!-- 140568954872776&#45;&gt;140568954872944 -->\n",
       "<g class=\"edge\" id=\"edge1\"><title>140568954872776-&gt;140568954872944</title>\n",
       "<path d=\"M156.5,-292.313C156.5,-284.289 156.5,-274.547 156.5,-265.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"160,-265.529 156.5,-255.529 153,-265.529 160,-265.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140568954886408 -->\n",
       "<g class=\"node\" id=\"node3\"><title>140568954886408</title>\n",
       "<polygon fill=\"none\" points=\"22,-146.5 22,-182.5 291,-182.5 291,-146.5 22,-146.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"156.5\" y=\"-160.8\">bidirectional_1(lstm_1): Bidirectional(LSTM)</text>\n",
       "</g>\n",
       "<!-- 140568954872944&#45;&gt;140568954886408 -->\n",
       "<g class=\"edge\" id=\"edge2\"><title>140568954872944-&gt;140568954886408</title>\n",
       "<path d=\"M156.5,-219.313C156.5,-211.289 156.5,-201.547 156.5,-192.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"160,-192.529 156.5,-182.529 153,-192.529 160,-192.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140568954582744 -->\n",
       "<g class=\"node\" id=\"node4\"><title>140568954582744</title>\n",
       "<polygon fill=\"none\" points=\"0,-73.5 0,-109.5 313,-109.5 313,-73.5 0,-73.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"156.5\" y=\"-87.8\">time_distributed_1(dense_1): TimeDistributed(Dense)</text>\n",
       "</g>\n",
       "<!-- 140568954886408&#45;&gt;140568954582744 -->\n",
       "<g class=\"edge\" id=\"edge3\"><title>140568954886408-&gt;140568954582744</title>\n",
       "<path d=\"M156.5,-146.313C156.5,-138.289 156.5,-128.547 156.5,-119.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"160,-119.529 156.5,-109.529 153,-119.529 160,-119.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140568954894376 -->\n",
       "<g class=\"node\" id=\"node5\"><title>140568954894376</title>\n",
       "<polygon fill=\"none\" points=\"117.5,-0.5 117.5,-36.5 195.5,-36.5 195.5,-0.5 117.5,-0.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"156.5\" y=\"-14.8\">crf_1: CRF</text>\n",
       "</g>\n",
       "<!-- 140568954582744&#45;&gt;140568954894376 -->\n",
       "<g class=\"edge\" id=\"edge4\"><title>140568954582744-&gt;140568954894376</title>\n",
       "<path d=\"M156.5,-73.3129C156.5,-65.2895 156.5,-55.5475 156.5,-46.5691\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"160,-46.5288 156.5,-36.5288 153,-46.5289 160,-46.5288\" stroke=\"black\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import SVG\n",
    "from keras.utils.vis_utils import model_to_dot\n",
    "\n",
    "# Render the layer graph as inline SVG using the Graphviz 'dot' program.\n",
    "SVG(model_to_dot(model).create(prog='dot', format='svg'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 47301 samples, validate on 5256 samples\n",
      "Epoch 1/2\n",
      "47301/47301 [==============================] - 305s 6ms/step - loss: 0.1326 - acc: 0.9597 - val_loss: 0.0056 - val_acc: 0.9982\n",
      "Epoch 2/2\n",
      "47301/47301 [==============================] - 304s 6ms/step - loss: 0.0041 - acc: 0.9986 - val_loss: 0.0025 - val_acc: 0.9992\n"
     ]
    }
   ],
   "source": [
    "# train_Y is a list of one-hot arrays, hence the np.array() conversion;\n",
    "# 10% of the training windows are held out for validation.\n",
    "history = model.fit(train_X, np.array(train_Y), batch_size=32, epochs=2,\n",
    "                    validation_split=0.1, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5840/5840 [==============================] - 12s 2ms/step\n"
     ]
    }
   ],
   "source": [
    "# Run the trained model on the held-out windows.\n",
    "predicted=model.predict(test_X,verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pred2label(pred, mapping=None):\n",
    "    \"\"\"Convert per-position class vectors into tag names.\n",
    "\n",
    "    Args:\n",
    "        pred: Iterable of sequences; each sequence holds one score vector\n",
    "            (or one-hot vector) per position.\n",
    "        mapping: Optional dict from class index to tag name. Defaults to the\n",
    "            notebook-level ``idx2tag``.\n",
    "\n",
    "    Returns:\n",
    "        list of list of str: The tag names, one inner list per sequence.\n",
    "    \"\"\"\n",
    "    lookup = idx2tag if mapping is None else mapping\n",
    "    out = []\n",
    "    for pred_i in pred:\n",
    "        scores = np.asarray(pred_i)\n",
    "        if scores.size == 0:\n",
    "            # argmax rejects empty input; an empty sequence has no tags.\n",
    "            out.append([])\n",
    "            continue\n",
    "        # One argmax over the last axis replaces one call per position.\n",
    "        best = np.argmax(scores, axis=-1)\n",
    "        out.append([lookup[index] for index in best.tolist()])\n",
    "    return out\n",
    "    \n",
    "pred_labels = pred2label(predicted)\n",
    "# NOTE: this rebinds test_labels, which earlier held the raw labels parsed\n",
    "# from data_test.txt.\n",
    "test_labels = pred2label(test_Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       OTHER       1.00      1.00      1.00    236927\n",
      "         law       1.00      0.99      1.00       500\n",
      "    location       1.00      1.00      1.00     10216\n",
      "organization       1.00      1.00      1.00     11752\n",
      "      person       1.00      1.00      1.00     19320\n",
      "    quantity       1.00      1.00      1.00      7097\n",
      "        time       1.00      0.99      1.00      6188\n",
      "\n",
      " avg / total       1.00      1.00      1.00    292000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "# Flatten both label grids so every token position counts as one sample.\n",
    "print(classification_report(np.array(test_labels).ravel(), np.array(pred_labels).ravel()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights only; the vocabularies were written to crf-lstm-bidirectional.json.\n",
    "model.save_weights('crf-lstm-bidirectional.h5')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
